9.3 - Code - Implementing an Action Mask
This section provides the updated code required to implement action masking in our DRL_RM_Bot environment. This is a critical upgrade that makes training the generalized agent feasible: the action space spans 50 × 2 × 100 = 10,000 possible actor-ability-target combinations, and early in a game only a small fraction of them refer to units that actually exist, so the mask focuses the agent's exploration on valid actions only.
Implementation Workflow
- Step 1: Install the sb3-contrib library.
- Step 2: Update drl_rm_bot.py to support the new Dict observation space and generate the mask.
- Step 3: Update train.py to use the MaskablePPO algorithm.
Step 1 - Install sb3-contrib
In your activated virtual environment, install the companion library:
pip install sb3-contrib
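To confirm the install succeeded in the active environment, you can print the library's version (any recent release works; this check only assumes a standard installation):
python -c "import sb3_contrib; print(sb3_contrib.__version__)"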
Step 2 - The Code: Updated drl_rm_bot.py with Masking Logic
This is the modified version of our DRL-RM environment. The most significant changes are in the DRL_RM_Env class, which now defines a Dict observation space with an "obs" key for the unit feature matrix and an "action_mask" key that carries the mask to the training process, and in the DRL_RM_Bot class, which gains a new method to generate the mask. Because the action space is MultiDiscrete, MaskablePPO expects one mask entry per value of each dimension (actor, ability, target), concatenated into a single flat vector. The constraints on the actor index and the target index are independent of each other, so masking each dimension separately expresses exactly the same restriction as masking actor-target pairs.
drl_rm_bot.py
import numpy as np
import gymnasium as gym
from gymnasium.spaces import Box, Dict, MultiDiscrete
import queue
from sc2.bot_ai import BotAI
from sc2.data import Race
from sc2.units import Units
from sc2_gym_env import SC2GymEnv
# --- Environment Constants ---
MAX_UNITS = 50
NUM_UNIT_FEATURES = 5
NUM_ABILITIES = 2 # 0=Move, 1=Attack
class DRL_RM_Bot(BotAI):
"""
The DRL-RM bot, now updated to generate an action mask.
"""
# NOTE: _encode_state() and _decode_and_execute_action() are kept from the previous section.
# For brevity, only the new/changed methods are shown here.
def __init__(self, action_queue, obs_queue):
super().__init__()
self.action_queue = action_queue
self.obs_queue = obs_queue
        self.race = Race.Terran
    def _get_action_mask(self) -> np.ndarray:
        """
        Calculates the action mask for the MultiDiscrete action space.
        MaskablePPO expects one boolean entry per value of each dimension
        (actor, ability, target), concatenated into a single flat vector of
        length MAX_UNITS + NUM_ABILITIES + MAX_UNITS * 2 = 152.
        An index is valid only if it points at a unit that currently exists.
        """
        my_units_len = len(self.units)
        all_units_len = len(self.all_units)
        # Dimension 0 (actor): only slots occupied by one of our own units are valid.
        actor_mask = np.arange(MAX_UNITS) < my_units_len
        # Dimension 1 (ability): Move and Attack are always considered valid.
        ability_mask = np.ones(NUM_ABILITIES, dtype=bool)
        # Dimension 2 (target): only slots occupied by a unit currently on the map are valid.
        target_mask = np.arange(MAX_UNITS * 2) < all_units_len
        # Example: with 12 friendly units and 30 units in total, the first 12 actor
        # entries and the first 30 target entries are True; everything else is False.
        return np.concatenate([actor_mask, ability_mask, target_mask])
async def on_step(self, iteration: int):
# The main loop now follows a cleaner Observe -> Act cycle.
if iteration % 8 == 0:
# 1. OBSERVE: Get the current state and generate the observation and mask.
observation = self._encode_state()
action_mask = self._get_action_mask()
# The observation is now a dictionary with "obs" and "action_mask" keys.
obs_dict = {
"obs": observation,
"action_mask": action_mask
}
terminated = self.townhalls.amount == 0
            # 2. SEND OBSERVATION: Send the observation, reward, and episode flags to the agent.
self.obs_queue.put((obs_dict, 0.1, terminated, False, {}))
if terminated:
await self.client.leave()
return
# 3. ACT: Get the (now guaranteed valid) action from the agent and execute it.
try:
action = self.action_queue.get_nowait()
await self._decode_and_execute_action(action)
except queue.Empty:
pass
class DRL_RM_Env(SC2GymEnv):
    """The Gymnasium wrapper, updated to use a Dict observation space."""
    def __init__(self):
        super().__init__(bot_class=DRL_RM_Bot, map_name="AcropolisLE")
        self.action_space = MultiDiscrete([MAX_UNITS, NUM_ABILITIES, MAX_UNITS * 2])
        # A Dict observation space lets the bot pass both the state and the mask
        # through the observation queue in a single object.
        self.observation_space = Dict({
            # The key "obs" holds our original unit feature matrix.
            "obs": Box(low=0, high=1, shape=(MAX_UNITS, NUM_UNIT_FEATURES), dtype=np.float32),
            # The key "action_mask" holds one boolean per value of each MultiDiscrete
            # dimension, so its length is nvec.sum() (50 + 2 + 100 = 152 entries).
            "action_mask": Box(low=0, high=1, shape=(int(self.action_space.nvec.sum()),), dtype=bool)
        })
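One more piece is needed: MaskablePPO does not read the mask out of the observation itself; during rollout collection it calls an action_masks() method on the environment (the same hook sb3-contrib's ActionMasker wrapper provides). Below is a minimal sketch of the methods you could add to DRL_RM_Env to expose that hook by caching the most recent mask. It assumes SC2GymEnv's reset() and step() return the Gymnasium-style tuples that DRL_RM_Bot places on the observation queue; adjust the overrides if your base class differs.
    # Additional DRL_RM_Env methods (sketch). Assumes SC2GymEnv.reset() returns
    # (obs, info) and SC2GymEnv.step() returns (obs, reward, terminated, truncated, info).
    def reset(self, *, seed=None, options=None):
        obs, info = super().reset(seed=seed, options=options)
        # Cache the mask that arrived with the first observation.
        self._last_action_mask = np.asarray(obs["action_mask"], dtype=bool)
        return obs, info

    def step(self, action):
        obs, reward, terminated, truncated, info = super().step(action)
        # Cache the mask that arrived with this observation; MaskablePPO will ask
        # for it before choosing the next action.
        self._last_action_mask = np.asarray(obs["action_mask"], dtype=bool)
        return obs, reward, terminated, truncated, info

    def action_masks(self) -> np.ndarray:
        """Called by MaskablePPO before each action is sampled."""
        if not hasattr(self, "_last_action_mask"):
            # No observation received yet: allow everything.
            return np.ones(int(self.action_space.nvec.sum()), dtype=bool)
        return self._last_action_mask
With these methods present, MaskablePPO recognizes the environment as mask-capable and requests a fresh mask before every action it samples.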
Step 3 - The Code: Updated train.py
The training script needs only two small changes: import MaskablePPO from sb3_contrib instead of PPO from stable_baselines3, and select the "MultiInputPolicy" policy, which Stable-Baselines3 uses for Dict observation spaces. During training, MaskablePPO calls the environment's action_masks() method (added above) before sampling each action, so nothing else in the script has to change.
# train.py
import multiprocessing as mp
# Import from sb3_contrib instead of stable_baselines3
from sb3_contrib import MaskablePPO
from drl_rm_bot import DRL_RM_Env
def main():
env = DRL_RM_Env()
    # Use MaskablePPO with "MultiInputPolicy", which handles the Dict observation space.
    model = MaskablePPO("MultiInputPolicy", env, verbose=1)
model.learn(total_timesteps=200_000)
model.save("ppo_masked_drl_rm")
env.close()
if __name__ == '__main__':
mp.freeze_support()
main()
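After training, the mask must also be supplied at inference time; MaskablePPO.predict() accepts it through its action_masks argument, and the helper get_action_masks() from sb3_contrib simply calls the environment's action_masks() method. The loop below is an illustrative sketch, assuming DRL_RM_Env follows the Gymnasium reset/step API as above.
# evaluate.py (illustrative sketch)
from sb3_contrib import MaskablePPO
from sb3_contrib.common.maskable.utils import get_action_masks
from drl_rm_bot import DRL_RM_Env

env = DRL_RM_Env()
model = MaskablePPO.load("ppo_masked_drl_rm")
obs, info = env.reset()
terminated = truncated = False
while not (terminated or truncated):
    # Ask the environment for the current mask and pass it to the policy.
    mask = get_action_masks(env)
    action, _ = model.predict(obs, action_masks=mask, deterministic=True)
    obs, reward, terminated, truncated, info = env.step(action)
env.close()
As with train.py, wrap this loop in an if __name__ == '__main__' guard if your SC2GymEnv launches the game in a subprocess.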
With these modifications, your framework is now equipped with action masking, a crucial technique for making the complex DRL-RM agent trainable.